import numpy as np
import logging
import pathlib
import xml.etree.ElementTree as ET
import cv2
import os


class VOCDataset:
    """Detection dataset for BDD100K images with Pascal-VOC style XML annotations.

    Image ids come from a plain-text id list (one id per line). For each id,
    the annotation ``<annotation_root>/<split>/<id>.xml`` and the image
    ``<image_root>/<split>/<id>.jpg`` are located by trying every split in
    ``SPLITS`` in order.
    """

    # Default (machine-specific) data locations. Override them per instance via
    # the ``annotation_root`` / ``image_root`` constructor arguments, or by
    # subclassing.
    ANNOTATION_ROOT = "/home/mju-hpc-01/LATran/MindinTech/ssd_project/bdd100k/bdd100k/xml/"
    IMAGE_ROOT = "/home/mju-hpc-01/LATran/MindinTech/ssd_project/bdd100k/bdd100k/images/100k/"
    # Sub-directories searched, in this order, for an id's annotation and image.
    SPLITS = ("train", "val")
    # Id lists for the test dataset and the train/val dataset.
    TEST_IDS_FILE = "./bdd_files/test.txt"
    TRAINVAL_IDS_FILE = "./bdd_files/trainval.txt"
    # Class names used when no label file is supplied; index 0 is background.
    DEFAULT_CLASS_NAMES = ('BACKGROUND',
                           'train', 'truck', 'traffic light', 'traffic sign',
                           'rider', 'person', 'bus', 'bike', 'car', 'motor')

    def __init__(self, root, transform=None, target_transform=None, is_test=False,
                 keep_difficult=False, label_file=None, annotation_root=None, image_root=None):
        """Build the dataset index and the class-name table.

        Args:
            root: accepted for interface compatibility with the VOC loader; it is
                currently not used to locate any file.
            transform: optional callable ``(image, boxes, labels) -> (image, boxes, labels)``.
            target_transform: optional callable ``(boxes, labels) -> (boxes, labels)``.
            is_test: read ids from ``TEST_IDS_FILE`` instead of ``TRAINVAL_IDS_FILE``.
            keep_difficult: if False, objects flagged ``<difficult>`` are dropped.
            label_file: optional path to a class-name file (names separated by
                commas and/or newlines). If missing, ``DEFAULT_CLASS_NAMES`` is used.
            annotation_root: directory holding ``<split>/<id>.xml`` files;
                defaults to ``ANNOTATION_ROOT``.
            image_root: directory holding ``<split>/<id>.jpg`` files;
                defaults to ``IMAGE_ROOT``.
        """
        self.transform = transform
        self.target_transform = target_transform
        self.keep_difficult = keep_difficult
        self.annotation_root = annotation_root if annotation_root is not None else self.ANNOTATION_ROOT
        self.image_root = image_root if image_root is not None else self.IMAGE_ROOT

        image_sets_file = self.TEST_IDS_FILE if is_test else self.TRAINVAL_IDS_FILE
        self.ids = VOCDataset._read_image_ids(image_sets_file)

        if label_file and os.path.isfile(label_file):
            with open(label_file, 'r') as infile:
                self.class_names = VOCDataset._parse_class_names(infile.read())
            logging.info("VOC Labels read from file: " + str(self.class_names))
        else:
            if label_file:
                # Keep the best-effort fallback, but don't hide a wrong path.
                logging.warning("Label file %s not found, using default BDD100K classes.", label_file)
            else:
                logging.info("No labels file, using default BDD100K classes.")
            self.class_names = self.DEFAULT_CLASS_NAMES

        self.class_dict = {class_name: i for i, class_name in enumerate(self.class_names)}

    def __getitem__(self, index):
        """Return ``(image, boxes, labels)`` for the sample at ``index``.

        Difficult objects are removed unless ``keep_difficult`` is set; then
        ``transform`` and ``target_transform`` are applied when configured.
        """
        image_id = self.ids[index]
        boxes, labels, is_difficult = self._get_annotation(image_id)
        if not self.keep_difficult:
            keep = is_difficult == 0
            boxes = boxes[keep]
            labels = labels[keep]
        image = self._read_image(image_id)
        if self.transform:
            image, boxes, labels = self.transform(image, boxes, labels)
        if self.target_transform:
            boxes, labels = self.target_transform(boxes, labels)
        return image, boxes, labels

    def get_image(self, index):
        """Return the (optionally transformed) RGB image at ``index``.

        NOTE(review): assumes ``transform(image)`` called with only an image
        returns a 2-tuple whose first item is the image — confirm against the
        transform actually passed in.
        """
        image_id = self.ids[index]
        image = self._read_image(image_id)
        if self.transform:
            image, _ = self.transform(image)
        return image

    def get_annotation(self, index):
        """Return ``(image_id, (boxes, labels, is_difficult))`` for ``index``."""
        image_id = self.ids[index]
        return image_id, self._get_annotation(image_id)

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _read_image_ids(image_sets_file):
        """Read one image id per line, dropping trailing whitespace and blank lines."""
        with open(image_sets_file) as f:
            ids = [line.rstrip() for line in f]
        return [image_id for image_id in ids if image_id]

    @staticmethod
    def _parse_class_names(text):
        """Parse comma- and/or newline-separated class names; prepend 'BACKGROUND'.

        Only surrounding whitespace is stripped so multi-word names such as
        'traffic light' still match the (space-preserving) annotation names.
        """
        names = [name.strip() for line in text.splitlines() for name in line.split(',')]
        return ('BACKGROUND',) + tuple(name for name in names if name)

    def _find_file(self, root, image_id, extension):
        """Return the first existing ``<root>/<split>/<image_id><extension>``.

        Splits are tried in ``SPLITS`` order (train before val).

        Raises:
            FileNotFoundError: if the file exists under none of the splits.
        """
        candidates = [os.path.join(root, split, image_id + extension) for split in self.SPLITS]
        for path in candidates:
            if os.path.isfile(path):
                return path
        raise FileNotFoundError(f"No file for image id {image_id!r}; tried: {candidates}")

    def _get_annotation(self, image_id):
        """Parse the XML annotation of ``image_id``.

        Returns:
            ``(boxes, labels, is_difficult)``: float32 array of shape ``(N, 4)``
            with ``[x1, y1, x2, y2]`` rows, int64 labels of shape ``(N,)`` and
            uint8 difficult flags of shape ``(N,)``. Objects whose class is not
            in ``class_dict`` are skipped.
        """
        annotation_file = self._find_file(self.annotation_root, image_id, ".xml")
        objects = ET.parse(annotation_file).findall("object")

        boxes = []
        labels = []
        is_difficult = []
        for obj in objects:
            class_name = obj.find('name').text.lower().strip()
            # we're only concerned with classes in our list
            if class_name not in self.class_dict:
                continue
            bbox = obj.find('bndbox')

            # VOC coordinates follow Matlab's 1-based indexing; shift to 0-based.
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            boxes.append([x1, y1, x2, y2])
            labels.append(self.class_dict[class_name])

            # Exactly one flag per kept object so the mask in __getitem__ lines up;
            # a missing or empty <difficult> tag counts as "not difficult".
            difficult = obj.find('difficult')
            difficult_text = difficult.text.strip() if difficult is not None and difficult.text else ""
            is_difficult.append(int(difficult_text) if difficult_text else 0)

        return (np.array(boxes, dtype=np.float32).reshape(-1, 4),
                np.array(labels, dtype=np.int64),
                np.array(is_difficult, dtype=np.uint8))

    def _read_image(self, image_id):
        """Load ``image_id`` as an RGB array.

        Raises:
            FileNotFoundError: if no ``.jpg`` exists under any split.
            IOError: if OpenCV cannot decode the file.
        """
        image_file = self._find_file(self.image_root, image_id, ".jpg")
        image = cv2.imread(str(image_file))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(f"Failed to read image: {image_file}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)